#pragma once
#include "DVector.hpp"
// Dynamic matrix: its size is changeable, and it can either own its data or not (unlike Array<2, T>).
namespace zzz{
template <typename T, bool OWN=true>
class DMatrixBase
{
public:
  // type definitions
  typedef T              value_type;
  typedef T*             iterator;
  typedef const T*       const_iterator;
  typedef T&             reference;
  typedef const T&       const_reference;
  typedef zsize          size_type;
  typedef size_type      difference_type;

protected:
  size_type nrow,ncol,size;
  size_type pit;
  T *v;

public:
  inline DVector<T,false> Row(size_type a)
  {
    return DVector<T,false>(v+a*pit,ncol);
  }
  inline const DVector<T,false> Row(size_type a) const
  {
    return DVector<T,false>(v+a*pit,ncol);
  }
  template<bool OWN1>
  inline void operator=(const DMatrixBase<T,OWN1>& a);
  template<bool OWN1>
  inline DMatrixBase<T,true> operator*(const DMatrixBase<T,OWN1>& a)
  {
    DMatrixBase<T,true> temp(nrow, a.ncol);
    for (size_type i=0; i<nrow; i++) for (size_type j=0; j<a.ncol; j++)
    {
      temp[i][j]=0.0f;
      for (size_type k=0; k<ncol; k++) temp[i][j]+=At(i, k)*a[k][j];
    }
    return temp;
  }
  template<bool OWN1>
  inline DVector<T,true> operator*(const DVector<T,OWN1>& a)
  {
    ZCHECK_EQ(ncol, a.size);
    DVector<T,true> ret(nrow);
    for (size_type i=0; i<nrow; i++) ret[i]=Row(i).Dot(a);
    return ret;
  }
  inline void operator=(const T& a)
  {
    for (size_type i=0; i<nrow; i++) for (size_type j=0; j<ncol; j++) At(i, j)=a;
  }

  inline void operator+=(const T a)
  {
    for (size_type i=0; i<nrow; i++) for (size_type j=0; j<ncol; j++) At(i, j)+=a;
  }
  inline void operator-=(const T a)
  {
    for (size_type i=0; i<nrow; i++) for (size_type j=0; j<ncol; j++) At(i, j)-=a;
  }
  inline void operator*=(const T a)
  {
    for (size_type i=0; i<nrow; i++) for (size_type j=0; j<ncol; j++) At(i, j)*=a;
  }
  inline void operator/=(const T a)
  {
    for (size_type i=0; i<nrow; i++) for (size_type j=0; j<ncol; j++) At(i, j)/=a;
  }
public:
  friend inline ostream& operator<<(ostream& os,const DMatrixBase<T> &me)
  {
    for (size_type i=0; i<me.nrow; i++)
    {
      for (size_type j=0; j<me.ncol; j++) zout<<me.v[i*me.pit+j]<<' ';
      zout<<'\n';
    }
  }
  inline T& At(size_type r,size_type c){return v[r*pit+c];}
  inline const T& At(size_type r,size_type c) const {return v[r*pit+c];}
  inline T& operator()(size_type r,size_type c){return v[r*pit+c];}
  inline const T& operator()(size_type r,size_type c) const {return v[r*pit+c];}
  void Negative() {
    for (size_type i=0; i<nrow; i++) for (size_type j=0; j<ncol; j++) 
      At(i, j)=-At(i, j);
  }
  void SetSize(size_type r,size_type c);
  void Transpose();
  void Zero();
  void Fill(const T &v);
public:
  //math
  inline T Sum()
  {
    T ret=0;
    for (size_type i=0; i<nrow; i++) for (size_type j=0; j<ncol; j++)
      ret+=At(i, j);
    return ret;
  }

  DMatrixBase():v(NULL),nrow(0),ncol(0),pit(0),size(0){}

};



//////////////////////////////////////////////////////////////////////////
template<typename T, bool OWN=true>
class DMatrix : public DMatrixBase<T,OWN>
{
// Need to be partial specified.
};

//specialize whole class for own
template<typename T>
class DMatrix<T,true> : public DMatrixBase<T,true>
{
public:
  using DMatrixBase<T,true>::size;
  using DMatrixBase<T,true>::v;
  using DMatrixBase<T,true>::nrow;
  using DMatrixBase<T,true>::ncol;
  using DMatrixBase<T,true>::pit;
  DMatrix(){}
  DMatrix(size_type row, size_type col){SetSize(row,col);}
  DMatrix(size_type row, size_type col, T f){
    SetSize(row, col);for (size_type i=0; i<size; i++) v[i]=f;}
  DMatrix(T *a, size_type row, size_type col) {
    v=new T[row * col];
    memcpy(v, a, sizeof(T) * row * col);
    nrow = row; ncol = col; pit = col;
  }
  DMatrix(const T *a,size_type row,size_type col)
  {
    v=new T[row*col];
    memcpy(v, a,sizeof(T)*row*col);
    nrow=row;ncol=col;pit=col;
  }
  DMatrix(T *a,zuint row,zuint col,zuint apit)
  {
    v=new T[row*col];
    for(zuint i=0; i<row; i++) memcpy(v+i*col, a+i*apit,sizeof(T)*col);
    pit=col;
    nrow=row;ncol=col;
  }
  //copy constructor
  DMatrix(const DMatrix<T,true> &other){*this=other;}
  template<bool OWN1> 
  DMatrix(const DMatrixBase<T,OWN1>& a){*this=a;}
  //operator=
  const DMatrix<T,true> &operator=(const DMatrix<T,true> &a)
  {
    SetSize(a.nrow, a.ncol);
    for (zuint i=0; i<nrow*ncol; i++) v[i]=a.v[i];
    return *this;
  }
  template<bool OWN1> 
  const DMatrixBase<T,true> &operator=(const DMatrixBase<T,OWN1>& a)
  {
    SetSize(a.nrow, a.ncol);
    for (zuint i=0; i<nrow*ncol; i++) v[i]=a.v[i];
    return *this;
  }

  ~DMatrix()
  {
    if (v) delete v;
  }
  void SetSize(zuint r,zuint c)
  {
    if (r==nrow && c==ncol) return;
    if (r*c!=size)
    {
      if (v) delete v;
      v=new T[r*c];
    }
    nrow=r;
    ncol=c;
    size=r*c;
    pit=c;
  }
  using DMatrixBase<T,true>::At;
  void Transpose()
  {
    DMatrix<T,true> temp(*this);
    SetSize(ncol,nrow);
    for (zuint i=0; i<nrow; i++) for (zuint j=0; j<ncol; j++)
      At(i, j)=temp.At(j, i);
  }
  inline T* Data() {return v;}
  inline const T* Data() const {return v;}
  void Zero()
  {
    memset(v, 0,sizeof(T)*nrow*ncol);
  }
  void Fill(const T &a)
  {
    zuint s=nrow*ncol;
    for (zuint i=0; i<s; i++) v[i]=a;
  }
};

//specialize whole class for not own
template<typename T>
class DMatrix<T,false> : public DMatrixBase<T,false>
{
public:
  using DMatrixBase<T,false>::size;
  using DMatrixBase<T,false>::v;
  using DMatrixBase<T,false>::nrow;
  using DMatrixBase<T,false>::ncol;
  using DMatrixBase<T,false>::pit;
  using DMatrixBase<T,false>::At;
//  DMatrix(){}
  DMatrix(T *a,zuint m,zuint n)
  {
    v=a;
    nrow=m;ncol=n;pit=n;
  }
  DMatrix(const T *a, zuint m, zuint n);
  DMatrix(T *a, zuint m, zuint n, zuint apit)
  {
    v=a;
    pit=apit;
    nrow=m;ncol=n;
  }
  DMatrix(const DMatrix<T,false>& a)
  {
    v=a.v;
    nrow=a.nrow;ncol=a.ncol;pit=a.pit;
  }
  template<bool OWN1> DMatrix(const DMatrixBase<T,OWN1>& a)
  {
    v=a.v;
    nrow=a.nrow;ncol=a.ncol;pit=a.pit;
  }
  //////////////////////////////////////////////////////////////////////////
  const DMatrix<T,false>& operator=(const DMatrix<T,false>& a)
  {
    ZCHECK(nrow=a.nrow && ncol=a.ncol);
    for (zuint r=0; r<nrow; r++) for (zuint c=0; c<ncol; c++)
      At(r, c)=a.At(r, c);
    return *this;
  }
  template<bool OWN1> const DMatrixBase<T,false> & operator=(const DMatrixBase<T,OWN1>& a)
  {
    ZCHECK(nrow=a.nrow && ncol=a.ncol);
    for (zuint r=0; r<nrow; r++) for (zuint c=0; c<ncol; c++)
      At(r, c)=a.At(r, c);
    return *this;
  }
  template<bool OWN1> void SetSource(const DMatrixBase<T,OWN1>& a)
  {
    v=a.v;
    nrow=a.nrow;ncol=a.ncol;pit=a.pit;
  }
  template<bool OWN1> void SetSource(const T *a, const zuint rows, const zuint cols)
  {
    v=a;
    nrow=rows;ncol=cols;pit=ncol;
  }
  void Transpose()
  {
    DMatrix<T> temp(ncol,nrow);
    memcpy(temp.v, v,sizeof(T)*nrow*ncol);
    for (zuint i=0; i<nrow; i++) for (zuint j=0; j<ncol; j++)
      At(j, i)=temp.At(i, j);
  }
  void Zero()
  {
    T *c=v;
    for (zuint i=0; i<nrow; i++)
    {
      memset(c, 0,sizeof(T)*ncol);
      c+=pit;
    }
  }
  void Fill(const T &a)
  {
    zuint s=nrow*ncol;
    T *c=v;
    for (zuint i=0; i<nrow; i++)
    {
      for (zuint j=0; j<ncol; j++) c[j]=a;
      c+=pit;
    }
  }
};

typedef DMatrix<zint8> DMatrixi8;
typedef DMatrix<zuint8> DMatrixui8;
typedef DMatrix<zint16> DMatrixi16;
typedef DMatrix<zuint8> DMatrixui16;
typedef DMatrix<zint32> DMatrixi32;
typedef DMatrix<zuint32> DMatrixui32;
typedef DMatrix<zint64> DMatrixi64;
typedef DMatrix<zuint64> DMatrixui64;
typedef DMatrix<zfloat32> DMatrixf32;
typedef DMatrix<zfloat64> DMatrixf64;


typedef DMatrix<float> DMatrixf;
typedef DMatrix<double> DMatrixd;
typedef DMatrix<int> DMatrixi;
typedef DMatrix<zuint> DMatrixui;
}
